


#####################################################
###### 05_additional_supplement
###### Other supplementary tables and figures 
#####################################################



rm(list = ls())
library(dplyr)
library(here)
library(stargazer)
library(xtable)
library(caret)
library(purrr)

#####################################################
###### Summarize LLM predictive performance 
#####################################################

# Load hand-coded truth alongside LLM predictions for the training and test
# samples, merging the two "no clear valence" codes into a single category.

# Replace "N/A" and "Other (salient but unclear)" with "N/A or valence unclear"
# in every column whose name contains "valence". Values are coerced to
# character first: ifelse() drops the factor class, so applied to a factor
# column it would keep the underlying integer codes for untouched values.
collapse_unclear_valence <- function(df) {
  unclear <- c("N/A", "Other (salient but unclear)")
  df %>%
    mutate(across(contains("valence"), function(x) {
      x <- as.character(x)
      ifelse(x %in% unclear, "N/A or valence unclear", x)
    }))
}

train_preds <- readRDS(here("data/llm_results_train.RDS")) %>%
  filter(model == "claude-3-5-sonnet-20240620") %>%
  collapse_unclear_valence()
# NOTE(review): unlike the training data, the test data is not filtered by
# `model`; confirm the test file only holds claude-3-5-sonnet-20240620 output.
test_preds <- readRDS(here("data/llm_results_test.RDS")) %>%
  collapse_unclear_valence()

# Functions to produce key classification metrics 
compute_binary_metrics <- function(truth, prediction, positive = "1") {
  # Compute accuracy, precision, recall and F1 for a binary classification
  # task via caret::confusionMatrix().
  #
  # Args:
  #   truth:      hand-coded labels; coerced to a factor.
  #   prediction: predicted labels, aligned element-wise with `truth`.
  #   positive:   label of the positive class (default "1").
  #
  # Returns a named numeric vector. Each element is taken from an already
  # named caret vector, so c() yields compound names: "Accuracy.Accuracy",
  # "Precision.Precision", "Recall.Recall", "F1_Score.F1". process_classifiers()
  # indexes on these names, so they are intentionally kept.
  #
  # Errors if `truth` plus `positive` do not form exactly two classes. Warns
  # when predictions are missing or fall outside those two classes: such rows
  # become NA and are dropped from the confusion table, and hence from every
  # metric including accuracy.

  # Add the positive label explicitly so a sample whose truth contains no
  # positive cases still yields a two-level factor instead of a caret error.
  classes <- union(levels(factor(truth)), as.character(positive))
  if (length(classes) != 2) {
    stop("Binary metrics need exactly two classes, got: ",
         paste(classes, collapse = ", "), call. = FALSE)
  }

  truth <- factor(truth, levels = classes)
  prediction <- factor(prediction, levels = classes)

  n_dropped <- sum(!is.na(truth) & is.na(prediction))
  if (n_dropped > 0) {
    warning(n_dropped, " prediction(s) missing or outside {",
            paste(classes, collapse = ", "),
            "} were excluded from the binary metrics.", call. = FALSE)
  }

  cm <- confusionMatrix(prediction, truth, positive = positive)
  by_class <- cm$byClass[c("Precision", "Recall", "F1")]
  accuracy <- cm$overall["Accuracy"]

  c(Accuracy = accuracy,
    Precision = by_class["Precision"],
    Recall = by_class["Recall"],
    F1_Score = by_class["F1"])
}

compute_multiclass_metrics <- function(truth, prediction) {
  # Compute accuracy, macro-averaged F1 and support-weighted F1 for a
  # multiclass task.
  #
  # Args:
  #   truth:      hand-coded labels; coerced to a factor.
  #   prediction: predicted labels, aligned element-wise with `truth`.
  #
  # Returns a named numeric vector c("Accuracy.Accuracy", "Macro_F1",
  # "Weighted_F1"). These are the names process_classifiers() indexes on.
  #
  # Metrics are computed directly from the confusion table rather than from
  # caret's byClass. byClass collapses to the first class only when there are
  # two classes, and it reports F1 as NA for a class that is never predicted.
  # With F1 = 2TP / (2TP + FP + FN), such a class correctly scores 0.
  # Labels are the union of truth and prediction labels, so predicting a class
  # that never occurs in `truth` counts as an error instead of being dropped.
  # Macro/weighted averages run over the truth classes only.
  truth_classes <- levels(factor(truth))
  all_classes <- union(truth_classes, levels(factor(prediction)))
  truth <- factor(truth, levels = all_classes)
  prediction <- factor(prediction, levels = all_classes)

  # table() silently drops NA pairs; make that visible.
  n_missing <- sum(is.na(truth) | is.na(prediction))
  if (n_missing > 0) {
    warning(n_missing, " observation(s) with a missing truth or prediction ",
            "were excluded from the multiclass metrics.", call. = FALSE)
  }

  # Rows = predicted class, columns = true class (caret's orientation).
  tab <- table(prediction, truth)
  tp <- unname(diag(tab))
  n_predicted <- unname(rowSums(tab))
  support <- unname(colSums(tab))

  accuracy <- c(Accuracy = sum(tp) / sum(tab))

  # 2TP + FP + FN == (#predicted as k) + (#truly k).
  f1 <- 2 * tp / (n_predicted + support)

  # Average over truth classes that still have observations after NA removal.
  scored <- all_classes %in% truth_classes & support > 0
  macro_f1 <- mean(f1[scored])
  weighted_f1 <- sum(f1[scored] * support[scored]) / sum(support[scored])

  c(Accuracy = accuracy,
    Macro_F1 = macro_f1,
    Weighted_F1 = weighted_f1)
}

# Function to process classifiers and compute metrics
process_classifiers <- function(classifier_list, dataset,
                                type = c("binary", "multiclass")) {
  # Score each task in `classifier_list` on `dataset`, one tibble row per task.
  #
  # Args:
  #   classifier_list: named list; each element is list(truth = <column>,
  #                    prediction = <column>), and its name becomes the
  #                    `Classifier` value of the output row.
  #   dataset:         data frame containing those columns.
  #   type:            "binary" (compute_binary_metrics(), positive class "1")
  #                    or "multiclass" (compute_multiclass_metrics()).
  #
  # Returns a tibble with `Classifier` plus, for "binary", Accuracy (percent),
  # Precision, Recall and F1_Score; for "multiclass", Accuracy (percent),
  # Macro_F1 and Weighted_F1. All values are rounded to 2 decimals.

  # Reject anything other than the two supported types; otherwise no row
  # would be produced and the task would vanish from the table.
  type <- match.arg(type)

  # The metric helpers build their result with c() on already-named values,
  # so names can be compound ("Accuracy.Accuracy", "F1_Score.F1"). Accept
  # either the short name or "<short name>.<suffix>".
  get_metric <- function(metrics, name) {
    nms <- as.character(names(metrics))
    idx <- which(nms == name)
    if (length(idx) == 0) {
      idx <- which(startsWith(nms, paste0(name, ".")))
    }
    if (length(idx) != 1) {
      stop("Metric '", name, "' not found exactly once among: ",
           paste(nms, collapse = ", "), call. = FALSE)
    }
    unname(metrics[[idx]])
  }

  rows <- map(names(classifier_list), function(classifier) {
    cols <- classifier_list[[classifier]]
    needed <- c(cols[["truth"]], cols[["prediction"]])
    missing_cols <- setdiff(needed, names(dataset))
    if (length(missing_cols) > 0) {
      stop("Classifier '", classifier, "': column(s) not found in dataset: ",
           paste(missing_cols, collapse = ", "), call. = FALSE)
    }
    truth <- dataset[[cols[["truth"]]]]
    prediction <- dataset[[cols[["prediction"]]]]

    if (type == "binary") {
      metrics <- compute_binary_metrics(truth, prediction, positive = "1")
      tibble(
        Classifier = classifier,
        Accuracy = round(get_metric(metrics, "Accuracy") * 100, 2),
        Precision = round(get_metric(metrics, "Precision"), 2),
        Recall = round(get_metric(metrics, "Recall"), 2),
        F1_Score = round(get_metric(metrics, "F1_Score"), 2)
      )
    } else {
      metrics <- compute_multiclass_metrics(truth, prediction)
      tibble(
        Classifier = classifier,
        Accuracy = round(get_metric(metrics, "Accuracy") * 100, 2),
        Macro_F1 = round(get_metric(metrics, "Macro_F1"), 2),
        Weighted_F1 = round(get_metric(metrics, "Weighted_F1"), 2)
      )
    }
  })

  bind_rows(rows)
}

# Classification tasks to score. Each entry maps a display name (used as the
# table row label) to list(truth = <hand-coded column>,
# prediction = <LLM column>).

# Binary salience tasks (scored with type = "binary").
binary_classifiers <- Map(
  function(truth_col, prediction_col) {
    list(truth = truth_col, prediction = prediction_col)
  },
  c("Salience of Targeting" = "salience_target_truth",
    "Salience of Impersonality" = "salience_impersonal_truth"),
  c("salience_target_prediction",
    "salience_impersonal_prediction")
)

# Multiclass valence tasks (scored with type = "multiclass").
multiclass_classifiers <- Map(
  function(truth_col, prediction_col) {
    list(truth = truth_col, prediction = prediction_col)
  },
  c("Valence of Targeting" = "valence_target_truth",
    "Valence of Impersonality" = "valence_impersonal_truth"),
  c("valence_target_prediction",
    "valence_impersonal_prediction")
)

# Score every task on the training split, then on the held-out test split.
salience_perform_train <- process_classifiers(
  classifier_list = binary_classifiers,
  dataset = train_preds,
  type = "binary"
)
valence_perform_train <- process_classifiers(
  classifier_list = multiclass_classifiers,
  dataset = train_preds,
  type = "multiclass"
)

salience_perform_test <- process_classifiers(
  classifier_list = binary_classifiers,
  dataset = test_preds,
  type = "binary"
)
valence_perform_test <- process_classifiers(
  classifier_list = multiclass_classifiers,
  dataset = test_preds,
  type = "multiclass"
)

# Echo the four result tables (auto-printed when evaluated at top level).
salience_perform_train
valence_perform_train
salience_perform_test
valence_perform_test

# Output: place train and test metrics side by side, one row per classifier.
# Every metric column gets a "_Train" or "_Test" suffix; rows are matched on
# `Classifier`, and all training-set rows are kept.
salience_combined <- left_join(
  rename_with(salience_perform_train,
              function(col) paste0(col, "_Train"),
              .cols = !Classifier),
  rename_with(salience_perform_test,
              function(col) paste0(col, "_Test"),
              .cols = !Classifier),
  by = "Classifier"
)

valence_combined <- left_join(
  rename_with(valence_perform_train,
              function(col) paste0(col, "_Train"),
              .cols = !Classifier),
  rename_with(valence_perform_test,
              function(col) paste0(col, "_Test"),
              .cols = !Classifier),
  by = "Classifier"
)

# Write a performance table to `file` as a booktabs LaTeX table without row
# names, creating the output directory if it does not exist yet.
# `size` must go to print.xtable(): xtable() has no `size` argument and
# silently ignores one passed through `...`, so it would never reach the .tex.
write_perf_table <- function(df, caption, label, file, size = "\\small") {
  dir.create(dirname(file), recursive = TRUE, showWarnings = FALSE)
  print(
    xtable(df, caption = caption, label = label),
    type = "latex",
    booktabs = TRUE,
    include.rownames = FALSE,
    size = size,
    file = file
  )
  invisible(file)
}

write_perf_table(
  salience_combined,
  caption = "Salience Classification Performance in Training and Test Set",
  label = "tab:salience_performance",
  file = here("output/tables/salience_perform.tex")
)

write_perf_table(
  valence_combined,
  caption = "Valence Classification Performance in Training and Test Set",
  label = "tab:valence_performance",
  file = here("output/tables/valence_perform.tex")
)

